{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in PyTorch\n",
    "\n",
    "Just like we did before for Q-learning, this time we'll design a PyTorch network to learn `CartPole-v1` via policy gradient (REINFORCE).\n",
    "\n",
    "Most of the code in this notebook is taken from approximate Q-learning, so you'll find it more or less familiar and even simpler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
    "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
    "    !pip install -q gymnasium\n",
    "    !pip install moviepy\n",
    "    !apt install ffmpeg\n",
    "    !pip install imageio-ffmpeg\n",
    "    !touch .setup_complete\n",
    "\n",
    "# This code creates a virtual display to draw game images on.\n",
    "# It will have no effect if your machine has a monitor.\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# also you need to install ffmpeg if not installed\n",
    "# for MacOS: ! brew install ffmpeg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A caveat: with some versions of `pyglet`, the following cell may crash with `NameError: name 'base' is not defined`. The corresponding bug report is [here](https://github.com/pyglet/pyglet/issues/134). If you see this error, try restarting the kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")\n",
    "\n",
    "# gym compatibility: unwrap TimeLimit\n",
    "if hasattr(env, '_max_episode_steps'):\n",
    "    env = env.env\n",
    "\n",
    "env.reset()\n",
    "n_actions = env.action_space.n\n",
    "state_dim = env.observation_space.shape\n",
    "\n",
    "plt.imshow(env.render())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the REINFORCE algorithm, we'll need a model that predicts action probabilities given states.\n",
    "\n",
    "For numerical stability, please __do not include the softmax layer in your network architecture__.\n",
    "We'll apply softmax or log-softmax to the logits where appropriate."
   ]
  },
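  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the stability issue (the logit values below are made up for illustration and are not part of the assignment), compare taking `log` of a softmax output with calling `log_softmax` on the logits directly:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical logits with a large gap between the two actions.\n",
    "logits = torch.tensor([[0.0, 200.0]])\n",
    "\n",
    "# Naive route: the small probability underflows to 0 in float32, so log gives -inf.\n",
    "naive_log_probs = torch.log(F.softmax(logits, dim=-1))\n",
    "\n",
    "# Stable route: log_softmax works in log space and stays finite.\n",
    "stable_log_probs = F.log_softmax(logits, dim=-1)\n",
    "\n",
    "print(naive_log_probs)\n",
    "print(stable_log_probs)\n",
    "```\n",
    "\n",
    "Keeping the network output as raw logits lets each consumer pick the appropriate function: `softmax` for sampling actions, `log_softmax` for the loss."
   ]
  },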
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a simple neural network that predicts policy logits. \n",
    "# Keep it simple: CartPole isn't worth deep architectures.\n",
    "model = nn.Sequential(\n",
    "  <YOUR CODE: define a neural network that predicts policy logits>\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Predict function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: this function returns a numpy array, not a torch tensor,\n",
    "so no gradient computation is needed here.\n",
    "<br>\n",
    "Use [no_grad](https://pytorch.org/docs/stable/autograd.html#torch.autograd.no_grad)\n",
    "to suppress gradient computation.\n",
    "<br>\n",
    "`.detach()` (or the legacy `.data` property) could be used instead, but there is a difference:\n",
    "<br>\n",
    "With `.detach()`, the computational graph is still built and then disconnected from one particular tensor,\n",
    "so `.detach()` is the right choice when that graph is needed for backprop through some other (non-detached) tensor.\n",
    "<br>\n",
    "In contrast, no operation inside a `no_grad()` context builds a graph, so `no_grad()` is preferable here."
   ]
  },
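  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two can be sketched on a toy tensor. The names `w`, `out_no_grad`, `out_graph` and `out_detached` are hypothetical and used only for this illustration:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "w = torch.ones(3, requires_grad=True)\n",
    "\n",
    "# Inside no_grad(), no graph is recorded at all.\n",
    "with torch.no_grad():\n",
    "    out_no_grad = w * 2\n",
    "\n",
    "# Outside no_grad(), the graph is built; detach() returns a tensor cut off from it.\n",
    "out_graph = w * 2\n",
    "out_detached = out_graph.detach()\n",
    "\n",
    "print(out_no_grad.requires_grad, out_no_grad.grad_fn)\n",
    "print(out_detached.requires_grad, out_graph.grad_fn is not None)\n",
    "\n",
    "# The non-detached tensor can still be used for backprop.\n",
    "out_graph.sum().backward()\n",
    "print(w.grad)\n",
    "```\n",
    "\n",
    "In both cases the resulting tensor carries no gradient information, but only `no_grad()` avoids the cost of building the graph in the first place."
   ]
  },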
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_probs(states):\n",
    "    \"\"\" \n",
    "    Predict action probabilities given states.\n",
    "    :param states: numpy array of shape [batch, state_shape]\n",
    "    :returns: numpy array of shape [batch, n_actions]\n",
    "    \"\"\"\n",
    "    # convert states, compute logits, use softmax to get probability\n",
    "    <YOUR CODE>\n",
    "    return <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_states = np.array([env.reset()[0] for _ in range(5)])\n",
    "test_probas = predict_probs(test_states)\n",
    "assert isinstance(test_probas, np.ndarray), \\\n",
    "    \"you must return np array and not %s\" % type(test_probas)\n",
    "assert tuple(test_probas.shape) == (test_states.shape[0], env.action_space.n), \\\n",
    "    \"wrong output shape: %s\" % np.shape(test_probas)\n",
    "assert np.allclose(np.sum(test_probas, axis=1), 1), \"probabilities do not sum to 1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Play the game\n",
    "\n",
    "We can now use our newly built agent to play the game."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_session(env, t_max=1000):\n",
    "    \"\"\"\n",
    "    Play a full session with REINFORCE agent.\n",
    "    Returns sequences of states, actions, and rewards.\n",
    "    \"\"\"\n",
    "    # arrays to record session\n",
    "    states, actions, rewards = [], [], []\n",
    "\n",
    "    s = env.reset()[0]\n",
    "\n",
    "    for t in range(t_max):\n",
    "        # action probabilities array aka pi(a|s)\n",
    "        action_probs = predict_probs(np.array([s]))[0]\n",
    "\n",
    "        # Sample action with given probabilities.\n",
    "        a = <YOUR CODE>\n",
    "\n",
    "        new_s, r, terminated, truncated, info = env.step(a)\n",
    "\n",
    "        # record session history to train later\n",
    "        states.append(s)\n",
    "        actions.append(a)\n",
    "        rewards.append(r)\n",
    "\n",
    "        s = new_s\n",
    "        if terminated or truncated:\n",
    "            break\n",
    "\n",
    "    return states, actions, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test it\n",
    "states, actions, rewards = generate_session(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "G_t &= r_t + \\gamma r_{t + 1} + \\gamma^2 r_{t + 2} + \\ldots \\\\\n",
    "&= \\sum_{i = t}^T \\gamma^{i - t} r_i \\\\\n",
    "&= r_t + \\gamma G_{t + 1}\n",
    "\\end{align*}\n",
    "$$"
   ]
  },
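  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the recursion (the reward values are made up for illustration), take a three-step session with rewards $r = (1, 0, 2)$ and $\\gamma = 0.5$, and go from the last step to the first:\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "G_2 &= r_2 = 2 \\\\\n",
    "G_1 &= r_1 + \\gamma G_2 = 0 + 0.5 \\cdot 2 = 1 \\\\\n",
    "G_0 &= r_0 + \\gamma G_1 = 1 + 0.5 \\cdot 1 = 1.5\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "The direct sum gives the same value: $G_0 = 1 + 0.5 \\cdot 0 + 0.25 \\cdot 2 = 1.5$."
   ]
  },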
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cumulative_rewards(rewards,  # rewards at each step\n",
    "                           gamma=0.99  # discount for reward\n",
    "                           ):\n",
    "    \"\"\"\n",
    "    Take a list of immediate rewards r(s,a) for the whole session \n",
    "    and compute cumulative returns (a.k.a. G(s,a) in Sutton '16).\n",
    "    \n",
    "    G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n",
    "\n",
    "    A simple way to compute cumulative rewards is to iterate from the last\n",
    "    to the first timestep and compute G_t = r_t + gamma*G_{t+1} recursively.\n",
    "\n",
    "    You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n",
    "    \"\"\"\n",
    "    <YOUR CODE>\n",
    "    return <YOUR CODE: array of cumulative rewards>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_cumulative_rewards(rewards)\n",
    "assert len(get_cumulative_rewards(list(range(100)))) == 100\n",
    "assert np.allclose(\n",
    "    get_cumulative_rewards([0, 0, 1, 0, 0, 1, 0], gamma=0.9),\n",
    "    [1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n",
    "assert np.allclose(\n",
    "    get_cumulative_rewards([0, 0, 1, -2, 3, -4, 0], gamma=0.5),\n",
    "    [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n",
    "assert np.allclose(\n",
    "    get_cumulative_rewards([0, 0, 1, 2, 3, 4, 0], gamma=0),\n",
    "    [0, 0, 1, 2, 3, 4, 0])\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define the objective and the policy-gradient update.\n",
    "\n",
    "Our objective function is\n",
    "\n",
    "$$ J \\approx  { 1 \\over N } \\sum_{s_i,a_i} G(s_i,a_i) $$\n",
    "\n",
    "REINFORCE defines a way to compute the gradient of the expected reward with respect to policy parameters. The formula is as follows:\n",
    "\n",
    "$$ \\nabla_\\theta \\hat J(\\theta) \\approx { 1 \\over N } \\sum_{s_i, a_i} \\nabla_\\theta \\log \\pi_\\theta (a_i \\mid s_i) \\cdot G(s_i, a_i) $$\n",
    "\n",
    "We can exploit PyTorch's automatic differentiation by defining a surrogate objective as follows:\n",
    "\n",
    "$$ \\hat J(\\theta) \\approx { 1 \\over N } \\sum_{s_i, a_i} \\log \\pi_\\theta (a_i \\mid s_i) \\cdot G(s_i, a_i) $$\n",
    "\n",
    "When you compute the gradient of that function with respect to the network weights $\\theta$, treating the returns $G(s_i, a_i)$ as constants, you get exactly the policy gradient above."
   ]
  },
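  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The claim that differentiating the surrogate objective yields the policy gradient can be checked on a toy example. This is only a sketch: the logits, actions and returns below are made-up values, the gradient is taken with respect to the logits rather than network weights, and the closed form assumes a softmax policy, for which $\\nabla_{z} \\log \\pi(a) = \\mathrm{onehot}(a) - \\pi$ where $z$ are the logits.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical batch of N=2 steps with 2 actions each.\n",
    "logits = torch.tensor([[0.5, -0.5], [0.1, 0.3]], requires_grad=True)\n",
    "actions = torch.tensor([0, 1])\n",
    "returns = torch.tensor([2.0, -1.0])  # treated as constants\n",
    "\n",
    "# Surrogate objective: mean of log pi(a_i | s_i) * G_i.\n",
    "log_probs = F.log_softmax(logits, dim=-1)\n",
    "log_probs_for_actions = log_probs[torch.arange(len(actions)), actions]\n",
    "surrogate = (log_probs_for_actions * returns).mean()\n",
    "surrogate.backward()\n",
    "\n",
    "# Closed-form gradient with respect to the logits for a softmax policy.\n",
    "probs = F.softmax(logits.detach(), dim=-1)\n",
    "analytic = returns[:, None] * (F.one_hot(actions, 2) - probs) / len(actions)\n",
    "\n",
    "print(torch.allclose(logits.grad, analytic, atol=1e-6))\n",
    "```\n",
    "\n",
    "Note that this sketch maximizes nothing and has no entropy term; it only compares the autograd result with the analytic gradient."
   ]
  },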
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Your code: define optimizers\n",
    "optimizer = torch.optim.Adam(model.parameters(), 1e-3)\n",
    "\n",
    "\n",
    "def train_on_session(states, actions, rewards, gamma=0.99, entropy_coef=1e-2):\n",
    "    \"\"\"\n",
    "    Takes a sequence of states, actions and rewards produced by generate_session.\n",
    "    Updates agent's weights by following the policy gradient above.\n",
    "    Please use Adam optimizer with default parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    # cast everything into torch tensors\n",
    "    states = torch.tensor(states, dtype=torch.float32)\n",
    "    actions = torch.tensor(actions, dtype=torch.int64)\n",
    "    cumulative_returns = np.array(get_cumulative_rewards(rewards, gamma))\n",
    "    cumulative_returns = torch.tensor(cumulative_returns, dtype=torch.float32)\n",
    "\n",
    "    # predict logits, probas and log-probas using an agent.\n",
    "    logits = model(states)\n",
    "    probs = nn.functional.softmax(logits, -1)\n",
    "    log_probs = nn.functional.log_softmax(logits, -1)\n",
    "\n",
    "    assert all(isinstance(v, torch.Tensor) for v in [logits, probs, log_probs]), \\\n",
    "        \"please compute using torch tensors and don't use the predict_probs function\"\n",
    "\n",
    "    # select log-probabilities for chosen actions, log pi(a_i|s_i)\n",
    "    log_probs_for_actions = torch.sum(\n",
    "        log_probs * F.one_hot(actions, env.action_space.n), dim=1)\n",
    "   \n",
    "    # Compute loss here. Don't forget entropy regularization with `entropy_coef`.\n",
    "    entropy = <YOUR CODE>\n",
    "    loss = <YOUR CODE>\n",
    "\n",
    "    # Gradient descent step\n",
    "    <YOUR CODE>\n",
    "\n",
    "    # technical: return session rewards to print them later\n",
    "    return np.sum(rewards)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The actual training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(100):\n",
    "    rewards = [train_on_session(*generate_session(env)) for _ in range(100)]  # generate new sessions\n",
    "    \n",
    "    print(\"mean reward: %.3f\" % (np.mean(rewards)))\n",
    "    \n",
    "    if np.mean(rewards) > 500:\n",
    "        print(\"You Win!\")  # but you can train even further\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results & video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record sessions\n",
    "\n",
    "from gymnasium.wrappers import RecordVideo\n",
    "\n",
    "with gym.make(\"CartPole-v1\", render_mode=\"rgb_array\") as env, RecordVideo(\n",
    "    env=env, video_folder=\"./videos\"\n",
    ") as env_monitor:\n",
    "    sessions = [generate_session(env_monitor) for _ in range(10)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show video. This may not work in some setups. If it doesn't\n",
    "# work for you, you can download the videos and view them locally.\n",
    "\n",
    "from pathlib import Path\n",
    "from base64 import b64encode\n",
    "from IPython.display import HTML\n",
    "\n",
    "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n",
    "video_path = video_paths[-1]  # You can also try other indices\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "    # https://stackoverflow.com/a/57378660/1214547\n",
    "    with video_path.open('rb') as fp:\n",
    "        mp4 = fp.read()\n",
    "    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n",
    "else:\n",
    "    data_url = str(video_path)\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(data_url))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
